package com.lambkit.plugin.activerecord;

import com.jfinal.plugin.activerecord.*;
import com.lambkit.db.*;
import com.lambkit.db.dialect.IDialect;
import com.lambkit.db.hutool.RowDataListHandler;

import javax.sql.DataSource;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;

/**
 * @author yangyong(孤竹行)
 */
public class JFinalDb extends JFinalDbOpt<RowData, PageData<RowData>> implements IDb {

    private DataSource dataSource;

    public JFinalDb(String configName)  {
        super(Db.use(configName), (IDialect) Db.use(configName).getConfig().getDialect());
        this.dataSource = getDb().getConfig().getDataSource();
    }

    public SqlPara toSqlPara(Sql sql) {
        SqlPara sqlPara = new SqlPara();
        sqlPara.setSql(sql.getSql());
        List<Object> paras = sql.getParaList();
        for(int i=0; i<paras.size(); i++) {
            sqlPara.addPara(paras.get(i));
        }
        return sqlPara;
    }

    public Sql toSql(SqlPara sqlPara) {
        Sql sql = new Sql();
        sql.setSql(sqlPara.getSql());
        Object[] paras = sqlPara.getPara();
        sql.setPara(paras);
        return sql;
    }

    @Override
    public Connection getConnection() throws SQLException {
        return getDb().getConfig().getConnection();
    }

    @Override
    public void closeConnection(Connection connection) {
        getDb().getConfig().close(connection);
    }

    @Override
    public DataSource getDataSource() {
        return dataSource;
    }

    @Override
    public RowData newRowData() {
        return new RowData();
    }

    @Override
    public PageData<RowData> newPageData() {
        return new PageData<RowData>();
    }

    @Override
    public <T extends RowModel<T>> IRowDao<T> dao(Class<T> clazz) {
        return new JFinalModelDao<>(this, clazz);
    }

    @Override
    public <T extends IRowData, P extends IPageData<T>> IDbOpt by(Class<T> rowClazz, Class<P> pageClazz) {
        return new JFinalDbService(this, rowClazz, pageClazz);
    }

    @Override
    public PageData<RowData> paginate(Integer pageNumber, Integer pageSize, Sql sql) {
        String[] sqls = PageSqlKit.parsePageSql(sql.getSql());
        String select = sqls[0];
        String sqlExceptSelect = sqls[1];
        return doPaginate(pageNumber, pageSize, (Boolean) null, select, sqlExceptSelect, sql.getPara());
    }

    @Override
    public List<RowData> find(String sql, Object... paras) {
        Connection conn = null;
        try {
            conn = getConnection();
            return find(getConfig(), conn, sql, paras);
        } catch (Exception e) {
            throw new ActiveRecordException(e);
        } finally {
            closeConnection(conn);
        }
    }

    @Override
    public RowData findFirst(String sql, Object... paras) {
        List<RowData> result = this.find(sql, paras);
        return result.size() > 0 ? result.get(0) : null;
    }

    protected PageData<RowData> doPaginate(int pageNumber, int pageSize, Boolean isGroupBySql, String select, String sqlExceptSelect, Object... paras) {
        Connection conn = null;
        try {
            conn = getConfig().getConnection();
            String totalRowSql = getConfig().getDialect().forPaginateTotalRow(select, sqlExceptSelect, null);
            StringBuilder findSql = new StringBuilder();
            findSql.append(select).append(' ').append(sqlExceptSelect);
            return doPaginateByFullSql(getConfig(), conn, pageNumber, pageSize, isGroupBySql, totalRowSql, findSql, paras);
        } catch (Exception e) {
            throw new ActiveRecordException(e);
        } finally {
            getConfig().close(conn);
        }
    }

    protected PageData<RowData> doPaginateByFullSql(Config config, Connection conn, int pageNumber, int pageSize, Boolean isGroupBySql, String totalRowSql, StringBuilder findSql, Object... paras) throws SQLException, InstantiationException, IllegalAccessException {
        if (pageNumber < 1 || pageSize < 1) {
            throw new ActiveRecordException("pageNumber and pageSize must more than 0");
        }
        if (config.getDialect().isTakeOverDbPaginate()) {
            //return getConfig().getDialect().takeOverDbPaginate(conn, pageNumber, pageSize, isGroupBySql, totalRowSql, findSql, paras);
            throw new RuntimeException("You should implements this method in " + config.getDialect().getClass().getName());
        }

        List result = query(config, conn, totalRowSql, paras);
        int size = result.size();
        if (isGroupBySql == null) {
            isGroupBySql = size > 1;
        }

        long totalRow;
        if (isGroupBySql) {
            totalRow = size;
        } else {
            totalRow = (size > 0) ? ((Number)result.get(0)).longValue() : 0;
        }
        if (totalRow == 0) {
            return new PageData<RowData>(new ArrayList<RowData>(0), pageNumber, pageSize, 0, 0);
        }

        int totalPage = (int) (totalRow / pageSize);
        if (totalRow % pageSize != 0) {
            totalPage++;
        }

        if (pageNumber > totalPage) {
            return new PageData<RowData>(new ArrayList<RowData>(0), pageNumber, pageSize, totalPage, (int)totalRow);
        }

        // --------
        String sql = config.getDialect().forPaginate(pageNumber, pageSize, findSql);
        List<RowData> list = find(config, conn, sql, paras);
        return new PageData<RowData>(list, pageNumber, pageSize, totalPage, (int)totalRow);
    }

    protected List<RowData> find(Config config, Connection conn, String sql, Object... paras) throws SQLException, InstantiationException, IllegalAccessException {
        try (PreparedStatement pst = conn.prepareStatement(sql)) {
            config.getDialect().fillStatement(pst, paras);
            ResultSet rs = pst.executeQuery();
            List<RowData> result = RowDataListHandler.create().handle(rs);
            close(rs);
            return result;
        }
    }

    protected <T> List<T> query(Config config, Connection conn, String sql, Object... paras) throws SQLException {
        List result = new ArrayList();
        try (PreparedStatement pst = conn.prepareStatement(sql)) {
            config.getDialect().fillStatement(pst, paras);
            ResultSet rs = pst.executeQuery();
            int colAmount = rs.getMetaData().getColumnCount();
            if (colAmount > 1) {
                while (rs.next()) {
                    Object[] temp = new Object[colAmount];
                    for (int i=0; i<colAmount; i++) {
                        temp[i] = rs.getObject(i + 1);
                    }
                    result.add(temp);
                }
            }
            else if(colAmount == 1) {
                while (rs.next()) {
                    result.add(rs.getObject(1));
                }
            }
            close(rs);
            return result;
        }
    }

    void close(ResultSet rs, Statement st) throws SQLException {
        if (rs != null) {rs.close();}
        if (st != null) {st.close();}
    }

    void close(ResultSet rs) throws SQLException {
        if (rs != null) {rs.close();}
    }

    void close(Statement st) throws SQLException {
        if (st != null) {st.close();}
    }
}
